import re
import torch
from util import *
import warnings
from transformers.trainer_pt_utils import LabelSmoother
# Suppress all Python warnings process-wide (this affects every module, not just this one).
warnings.filterwarnings("ignore")

# Label value for token positions that should not contribute to the loss
# (transformers' LabelSmoother.ignore_index).
IGNORE_TOKEN_ID = LabelSmoother.ignore_index

def generate_all_actions(state):
    """List every Blocksworld action that is legal in ``state``.

    ``state`` is a natural-language description such as
    "the red block is clear, the hand is empty, ...".

    * Hand empty: each clear block is either picked up (when it is on the
      table) or unstacked from the block it sits on.
    * Hand holding a block: the held block can be stacked onto any clear
      block, and finally put down.

    Returns a list of action strings ordered by the appearance of the clear
    blocks in ``state``; "Put down ..." (when present) is always last.
    Raises AttributeError if the state mentions a non-empty hand without a
    "is holding the <color> block" phrase, or a clear block that is neither
    on the table nor on top of another block.
    """
    clear_phrases = re.findall("the [a-z]{0,10} block is clear", state)

    if "hand is empty" not in state:
        held = re.search("is holding the ([a-z]{0,10}) block", state).group(1)
        targets = [re.search("the ([a-z]{0,10}) block is clear", p).group(1) for p in clear_phrases]
        stack_moves = [f"Stack the {held} block on top of the {target} block" for target in targets]
        return stack_moves + [f"Put down the {held} block"]

    moves = []
    for phrase in clear_phrases:
        color = re.search("the ([a-z]{0,10}) block is clear", phrase).group(1)
        if f"the {color} block is on the table" in state:
            moves.append(f"Pick up the {color} block")
            continue
        # Not on the table, so it must be resting on another block.
        below = re.search(f"the {color} block" + " is on top of the ([a-z]{0,10}) block", state).group(1)
        moves.append(f"Unstack the {color} block from on top of the {below} block")
    return moves

def _raise_state_error(message, detail=None):
    """Print a diagnostic (and optional detail line), then raise ``Exception("ERROR")``.

    A ``torch.distributed.barrier()`` is issued first only when a process
    group is actually initialised: calling it otherwise raises its own error
    and masks the intended one.
    NOTE(review): a barrier reached by only one rank can block until the
    other ranks hit a barrier too — confirm this synchronisation is wanted.
    """
    print(message)
    if detail is not None:
        print(detail)
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.barrier()
    raise Exception("ERROR")


def _fact_priority(fact):
    """Sort rank of a single state fact (lower ranks are listed first)."""
    if "have that" in fact:
        return 0
    if "clear" in fact or "in the hand" in fact:
        return 1
    if "the hand is" in fact:
        return 2
    if "on top of" in fact:
        return 3
    if "on the table" in fact:
        return 4
    _raise_state_error("Error: unknown state", fact)


def apply_change(change, state):
    """Apply a natural-language world-model change to a Blocksworld state.

    Args:
        change: comma-separated change description, e.g.
            "the hand was empty and is now holding the red block, and the red
            block is no longer clear." It is lower-cased before matching.
            Each clause is one of:
              * "the hand was <old> and is now <new>"
              * "the <color> block is no longer <old>"   (removes the fact)
              * "the <color> block was <old> and is now <new>"
              * "the <color> block is now <new>"          (adds a fact)
        state: comma-separated state description, e.g.
            "the red block is clear, the hand is empty, and the red block is
            on the table."

    Returns:
        The updated state as one sentence: facts ordered by ``_fact_priority``
        and then alphabetically, joined by ", ", with "and " before the last
        fact and a trailing ".". If every fact was removed the result is ".".
        Returns the string "Infeasible" as soon as a clause matches no fact.

    Raises:
        Exception("ERROR"): a clause names no block, does not start with
            "the hand" / "the <color> block", or a resulting fact has no
            known category.
    """
    # Turn "... X and the Y" into "..., and the Y" so the last clause splits
    # off on ", " like the others.
    if "and the " in state and ", and the" not in state:
        state = state.replace("and the ", ", and the ")
    facts = [
        s.strip()[4:].strip(".") if s.strip().startswith("and ") else s.strip().strip(".")
        for s in state.split(", ")
    ]

    for clause in change.lower().strip().strip(".").split(", "):
        if clause.startswith("and "):
            clause = clause[4:]
        success = 0

        if clause.startswith("the hand"):
            old = clause.split("was")[1].split("and")[0].strip()
            new = clause.split("now")[1].strip()
            for idx, fact in enumerate(facts):
                if ("hand is " + old) in fact:
                    facts[idx] = fact.replace(old, new)
                    success += 1
        else:
            colors = re.findall(r"the (\w+) block", clause)
            if len(colors) == 0:
                _raise_state_error("Error: zero-colors", clause)
            color = colors[0]
            if not clause.startswith(f"the {color} block"):
                _raise_state_error("Error: not recognized")

            if "no longer" in clause:
                old = clause.split("no longer")[1].strip()
                for idx, fact in enumerate(facts):
                    if f"{color} block is " + old in fact:
                        facts[idx] = ""  # dropped below
                        success += 1
            elif "was" in clause and "now" in clause:
                old = clause.split("was")[1].split(" and")[0].strip()
                new = clause.split("now")[1].strip()
                for idx, fact in enumerate(facts):
                    if f"{color} block is " + old in fact:
                        facts[idx] = fact.replace(old, new)
                        success += 1
            elif "now" in clause:
                new = clause.split("now")[1].strip()
                facts.append("the " + color + " block is " + new)
                success += 1

        # A clause that touches no existing fact means the world model
        # described an impossible transition.
        if success == 0:
            return "Infeasible"

    facts = [f for f in facts if f != ""]
    ordered = [f.strip() for f in sorted(facts, key=lambda f: (_fact_priority(f), f))]
    if ordered:
        ordered[-1] = "and " + ordered[-1]
    return ", ".join(ordered) + "."

def query_LM(worldmodel, tokenizer, prompt, eos_token_id, num_return_sequences=1, do_sample=True, temperature=0.7, max_new_tokens=100):
    """Generate continuations of ``prompt`` with ``worldmodel``.

    Args:
        worldmodel: causal LM exposing ``generate(...)`` and ``device``
            (e.g. a Hugging Face / PEFT model).
        tokenizer: tokenizer with ``encode(..., return_tensors='pt')`` and ``decode``.
        prompt: prompt text.
        eos_token_id: token id that stops generation.
        num_return_sequences: number of samples to return when sampling.
            Greedy decoding (``do_sample=False``) always returns one result.
        do_sample: sample with top-k=10 and ``temperature`` if True,
            otherwise decode greedily.
        temperature: sampling temperature (used only when ``do_sample``).
        max_new_tokens: generation budget beyond the prompt.

    Returns:
        list[str]: one entry per generated sequence. Each entry is the full
        decoded sequence (prompt included, special tokens kept), cut just
        before its last "\n" so that a trailing partial line is dropped.
    """
    # Follow the model's device instead of assuming the current CUDA device.
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(worldmodel.device)

    gen_kwargs = dict(max_new_tokens=max_new_tokens, eos_token_id=eos_token_id, do_sample=do_sample)
    if do_sample:
        # Sampling knobs are only meaningful (and only passed) when sampling.
        gen_kwargs.update(top_k=10, temperature=temperature, num_return_sequences=num_return_sequences)
    outputs = worldmodel.generate(input_ids, **gen_kwargs)

    all_results = []
    for sequence in outputs:
        text = tokenizer.decode(sequence, skip_special_tokens=False)
        last_newline_position = text.rfind('\n')
        all_results.append(text[:last_newline_position] if last_newline_position != -1 else text)
    return all_results

def preprocess(tokenizer, state, actions, ignore_index=None):
    """Build a causal-LM sample whose loss covers only the action tokens.

    Args:
        tokenizer: tokenizer with
            ``encode(text, return_tensors="pt", add_special_tokens=...)``.
        state: full trajectory text (prompt, states and actions).
        actions: action strings, in the order they appear in ``state``.
        ignore_index: label value for positions excluded from the loss.
            Defaults to the module-level ``IGNORE_TOKEN_ID``.

    Returns:
        dict with ``input_ids``, ``labels`` and ``attention_mask``, each of
        shape (1, seq_len). ``labels`` equals ``input_ids`` on the full token
        span of every action found in ``state`` and ``ignore_index`` elsewhere.
        Actions whose token ids cannot be found verbatim are left unlabelled.
    """
    if ignore_index is None:
        ignore_index = IGNORE_TOKEN_ID

    # reshape(-1) rather than squeeze(): a 1-token text must stay 1-D.
    input_ids = tokenizer.encode(state, return_tensors="pt").reshape(-1)
    labels = torch.full_like(input_ids, ignore_index)
    seq_len = input_ids.shape[0]

    cursor = 0  # actions occur in order; resume after the previous match
    for action in actions:
        # No BOS/EOS: they never occur in the middle of ``state``.
        action_ids = tokenizer.encode(action, return_tensors="pt", add_special_tokens=False).reshape(-1)
        width = action_ids.shape[0]
        if width == 0:
            continue
        for start in range(cursor, seq_len - width + 1):
            if torch.equal(input_ids[start:start + width], action_ids):
                labels[start:start + width] = input_ids[start:start + width]
                cursor = start + width
                break

    attention_mask = torch.ones(input_ids.shape, dtype=torch.long)

    return dict(
        input_ids=input_ids.unsqueeze(0),
        labels=labels.unsqueeze(0),
        attention_mask=attention_mask.unsqueeze(0)
    )

# (action keyword, prompts key) pairs. "Unstack" must be tested before
# "Stack" because the former contains the latter.
_WORLD_UPDATE_PROMPTS = (
    ("Pick", "world_update_pickup"),
    ("Unstack", "world_update_unstack"),
    ("Put", "world_update_putdown"),
    ("Stack", "world_update_stack"),
)


def _select_world_update_prompt(prompts, last_state, last_action):
    """Format the world-model prompt matching ``last_action``'s verb.

    The verb is matched in its capitalised or all-lowercase form. Returns
    None when the action contains none of the known verbs.
    """
    for keyword, key in _WORLD_UPDATE_PROMPTS:
        if keyword in last_action or keyword.lower() in last_action:
            return prompts[key].format(last_state, last_action)
    return None


def _goal_log_reward(trajectory_text, world_state):
    """Log-reward for how many goal facts hold in ``world_state``.

    Goal facts are the "the X block is on top of the Y block" phrases found
    between "[GOAL]" and "[STATE 0]" in ``trajectory_text``. The reward is
    100 when every goal fact (vacuously, if there are none) occurs verbatim
    in ``world_state``; otherwise it is the satisfied fraction plus 0.5.
    Returns ``torch.log`` of that reward as a tensor.
    """
    goal_statement = trajectory_text.split("[GOAL]")[-1].split("[STATE 0]")[0]
    goals = re.findall("the [a-z]{0,10} block is on top of the [a-z]{0,10} block", goal_statement)
    meetings = [g in world_state for g in goals]
    if sum(meetings) == len(meetings):
        r1 = 100
    else:
        r1 = sum(meetings) / len(meetings) + 0.5
    return torch.log(torch.tensor(r1))


def generate_trajectories(initial_state,
                          goal,
                          prompts,
                          model,
                          tokenizer,
                          max_steps,
                          temperature,
                          eos_token_id,
                          agent=None,
                          ):
    """Roll out one Blocksworld trajectory, alternating policy and world model.

    Each step samples an action from ``model`` (policy / LoRA mode). It then
    switches the model to world-model mode with ``lora_to_base`` and greedily
    generates the state change with ``query_LM``. The change is applied with
    ``apply_change`` and the model is switched back with ``base_to_lora``.
    The rollout stops early when an action has no recognised verb or when the
    world model's change is infeasible.

    Args:
        initial_state: natural-language initial state.
        goal: natural-language goal description.
        prompts: dict with "goal_prefix", "state_prefix", "action_prefix"
            (``str.format`` templates taking the step index) and the
            "world_update_{pickup,unstack,putdown,stack}" templates.
        model: causal LM used as both policy and world model.
        tokenizer: tokenizer matching ``model``.
        max_steps: maximum number of actions to generate.
        temperature: NOTE(review): currently unused — action sampling runs
            with the model's default temperature.
        eos_token_id: stop token for both generation calls.
        agent: unused.

    Returns:
        ``(state, actions, log_reward, sample)``: the accumulated trajectory
        text, the list of generated action strings, ``torch.log`` of the goal
        reward (see ``_goal_log_reward``) and the ``preprocess`` sample built
        from ``state``/``actions``. On every return path the model is left in
        policy (LoRA) mode.
    """
    base_prompt = "I am playing with a set of blocks where I need to arrange the blocks into stacks. Here are the actions I can do\n\nPick up a block\nUnstack a block from on top of another block\nPut down a block\nStack a block on top of another block\n\nI have the following restrictions on my actions:\nI can only pick up or unstack one block at a time.\nI can only pick up or unstack a block if my hand is empty.\nI can only pick up a block if the block is on the table and the block is clear. A block is clear if the block has no other blocks on top of it and if the block is not picked up.\nI can only unstack a block from on top of another block if the block I am unstacking was really on top of the other block.\nI can only unstack a block from on top of another block if the block I am unstacking is clear.\nOnce I pick up or unstack a block, I am holding the block.\nI can only put down a block that I am holding.\nI can only stack a block on top of another block if I am holding the block being stacked.\nI can only stack a block on top of another block if the block onto which I am stacking the block is clear.\nOnce I put down or stack a block, my hand becomes empty.\nAfter being given an initial state and an action, give the new state after performing the action.\n"

    state = base_prompt + prompts["goal_prefix"] + goal.strip() + "\n" + prompts["state_prefix"].format(0) + " " + initial_state.strip() + "\n"
    actions = []
    # Latest world state. It is the initial state until a step succeeds, so
    # the final reward is well defined even when max_steps == 0.
    new_state = initial_state.strip()

    for step in range(max_steps):
        # --- Policy: sample the next action (first generated line). ---
        state += prompts["action_prefix"].format(step + 1) + " "
        input_ids = tokenizer.encode(state, return_tensors='pt').to(model.device)
        agent_output = model.generate(input_ids, max_length=input_ids.shape[-1] + 20, do_sample=True, top_k=10, eos_token_id=eos_token_id)
        decoded_output = tokenizer.decode(agent_output[0][input_ids.shape[-1]:], skip_special_tokens=True)
        newline_position = decoded_output.find('\n')
        action = decoded_output[:newline_position] if newline_position != -1 else decoded_output
        state += action
        actions.append(action)

        last_state = re.search(f'.*{re.escape(prompts["state_prefix"].format(step))}(.*)', state)[1]
        last_action = re.search(f'.*{re.escape(prompts["action_prefix"].format(step+1))}(.*)', state)[1]
        # State text between this step's [STATE] and [ACTION] markers.
        prev_state = state.split(f"[STATE {step}]")[-1].split(f"[ACTION {step+1}]")[0]

        world_update_prompt = _select_world_update_prompt(prompts, last_state, last_action)
        if world_update_prompt is None:
            # Unrecognised action: end the trajectory as infeasible instead of
            # crashing (step 0) or reusing the previous step's prompt.
            sample = preprocess(tokenizer, state, actions)
            return state, actions, _goal_log_reward(state, prev_state), sample

        # --- World model: predict the state change. ---
        lora_to_base(model)  # switch to world-model mode
        try:
            world_output = query_LM(model, tokenizer, world_update_prompt, do_sample=False, num_return_sequences=1,
                                    eos_token_id=eos_token_id)[0]
        finally:
            base_to_lora(model)  # always restore policy mode, even on early return/error

        world_change = world_output.split("[CHANGE]")[-1]
        new_state = apply_change(world_change, prev_state)

        if new_state == "Infeasible":
            # NOTE: the actions' log-likelihood under the world model could be
            # used as the reward here instead.
            sample = preprocess(tokenizer, state, actions)
            return state, actions, _goal_log_reward(state, prev_state), sample

        state += "\n" + prompts["state_prefix"].format(step+1) + " " + new_state + "\n"

    sample = preprocess(tokenizer, state, actions)
    return state, actions, _goal_log_reward(state, new_state), sample